{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "id": "18zNHCLg5C3t"
   },
   "outputs": [],
   "source": [
    "# Install the notebook's third-party dependencies.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel,\n",
    "# and the extras spec is quoted so the shell cannot glob-expand the brackets.\n",
    "# TODO: pin the package versions this notebook was validated against.\n",
    "%pip install -q datasets evaluate \"transformers[torch]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VzKzjJJb542a"
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "import shutil\n",
    "\n",
    "import datasets\n",
    "import evaluate\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from datasets import load_dataset, DatasetDict\n",
    "from evaluate import evaluator\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer, pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FfOmOPtnuMOw"
   },
   "outputs": [],
   "source": [
    "# Fixed seed so the train/validation/test partition is identical on every run;\n",
    "# without it the shuffled splits (and therefore every reported metric) change\n",
    "# each time the notebook is executed.\n",
    "SPLIT_SEED = 42\n",
    "\n",
    "# load_dataset returns a DatasetDict; the rows of the CSV are under its \"train\" key.\n",
    "raw_dataset = load_dataset(\"csv\", data_files=\"./topic_train.csv\")[\"train\"]\n",
    "\n",
    "# Hold out 30 % of the rows, then halve that hold-out into validation and test.\n",
    "train_testvalid = raw_dataset.train_test_split(test_size=0.3, seed=SPLIT_SEED)\n",
    "test_valid = train_testvalid[\"test\"].train_test_split(test_size=0.5, seed=SPLIT_SEED)\n",
    "\n",
    "dataset = DatasetDict({\n",
    "    \"train\": train_testvalid[\"train\"],\n",
    "    \"test\": test_valid[\"test\"],\n",
    "    \"valid\": test_valid[\"train\"],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7MdWEZVV87bL"
   },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"KB/bert-base-swedish-cased\")\n",
    "\n",
    "# Collator that pads the examples of a batch using the tokenizer above.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize the `headl_text` column of a batch, truncated to at most 512 tokens.\"\"\"\n",
    "    texts = examples[\"headl_text\"]\n",
    "    encoded = tokenizer(texts, truncation=True, max_length=512)\n",
    "    return encoded\n",
    "\n",
    "\n",
    "# Apply the tokenizer batch-wise to every split of the DatasetDict.\n",
    "tokenized_data = dataset.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "57JqV4KPWH-t"
   },
   "outputs": [],
   "source": [
    "# Pretrained Swedish BERT checkpoint with a sequence-classification head\n",
    "# configured for five labels.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    \"KB/bert-base-swedish-cased\",\n",
    "    num_labels=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uX30SRXX9BV9"
   },
   "outputs": [],
   "source": [
    "# Fine-tuning hyper-parameters; training artefacts go to ./results.\n",
    "# NOTE(review): no evaluation strategy is configured, so the validation split\n",
    "# passed below is presumably only used if trainer.evaluate() is called - confirm.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./results\",\n",
    "    num_train_epochs=8,\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    ")\n",
    "\n",
    "# Wire the model, the tokenized splits and the padding collator together.\n",
    "trainer = Trainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=tokenized_data[\"train\"],\n",
    "    eval_dataset=tokenized_data[\"valid\"],\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9oxwZ850-HI-"
   },
   "outputs": [],
   "source": [
    "# Run the trained model on the test split; the arg-max over the last axis of\n",
    "# the returned scores is taken as the predicted class id.\n",
    "predictions = trainer.predict(tokenized_data[\"test\"])\n",
    "preds = np.argmax(predictions.predictions, axis=-1)\n",
    "\n",
    "# Metric objects are kept under their own names; `f1`, `recall` and `precision`\n",
    "# hold the computed result dictionaries.\n",
    "recall_metric = evaluate.load(\"recall\")\n",
    "f1_metric = evaluate.load(\"f1\")\n",
    "precision_metric = evaluate.load(\"precision\")\n",
    "\n",
    "# All three scores use the same inputs and weighted averaging over the classes.\n",
    "scoring_inputs = dict(predictions=preds, references=predictions.label_ids, average=\"weighted\")\n",
    "f1 = f1_metric.compute(**scoring_inputs)\n",
    "recall = recall_metric.compute(**scoring_inputs)\n",
    "precision = precision_metric.compute(**scoring_inputs)\n",
    "\n",
    "for scores in (precision, recall, f1):\n",
    "    print(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lECd5eTVY1Jc"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# Reference labels are passed as y_true and the arg-max class ids as y_pred.\n",
    "conf_matrix = confusion_matrix(y_true=predictions.label_ids, y_pred=preds)\n",
    "print(\"Confusion Matrix:\\n\", conf_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bTBeU1So84da"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, balanced_accuracy_score\n",
    "\n",
    "# get disaggregated (one-vs-rest) metrics, one group of scores per class.\n",
    "# Each class is binarised directly from the id arrays instead of indexing\n",
    "# hard-coded one-hot columns, so a class that is never predicted (or that is\n",
    "# missing from the test split) no longer raises a KeyError. The ids are\n",
    "# recomputed from `predictions` and nothing upstream is overwritten, so the\n",
    "# cell can be re-run safely.\n",
    "true_ids = np.asarray(predictions.label_ids)\n",
    "pred_ids = np.argmax(predictions.predictions, axis=-1)\n",
    "n_classes = predictions.predictions.shape[-1]  # size of the axis the arg-max runs over\n",
    "\n",
    "for class_id in range(n_classes):\n",
    "    # 1 where the example belongs to / was assigned to this class, else 0.\n",
    "    y_true = (true_ids == class_id).astype(int)\n",
    "    y_pred = (pred_ids == class_id).astype(int)\n",
    "\n",
    "    accuracy = accuracy_score(y_true, y_pred)\n",
    "    precision = precision_score(y_true, y_pred)\n",
    "    recall = recall_score(y_true, y_pred)\n",
    "    f1 = f1_score(y_true, y_pred)\n",
    "    balanced_acc = balanced_accuracy_score(y_true, y_pred)\n",
    "\n",
    "    # Label each group so the scores can be attributed to a class.\n",
    "    print(f\"Class {class_id}\")\n",
    "    print(f\"Precision: {precision:.4f}\")\n",
    "    print(f\"Recall: {recall:.4f}\")\n",
    "    print(f\"Accuracy: {accuracy:.4f}\")\n",
    "    print(f\"Balanced Accuracy: {balanced_acc:.4f}\")\n",
    "    print(f\"F1 Score: {f1:.4f}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AbJlg1qWfzYL"
   },
   "outputs": [],
   "source": [
    "# Persist the fine-tuned model; the inference cells load it from this directory.\n",
    "trainer.save_model(output_dir=\"./fine_tuned_model_topic\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "esTCXaerzkR_"
   },
   "outputs": [],
   "source": [
    "# get articles for prediction part 1\n",
    "articles = pd.read_csv('./all_articles_for_pred_1.csv')\n",
    "\n",
    "# transformers wants a text list as input\n",
    "text = articles['headl_text'].values.tolist()\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"KB/bert-base-swedish-cased\")\n",
    "\n",
    "# load model\n",
    "fine_tuned_model = AutoModelForSequenceClassification.from_pretrained(\"./fine_tuned_model_topic\", num_labels=5)\n",
    "\n",
    "# Use the first GPU when one is visible, otherwise fall back to the CPU (-1)\n",
    "# instead of failing on a hard-coded device index.\n",
    "device = 0 if torch.cuda.is_available() else -1\n",
    "\n",
    "# set up and run pipeline for inference\n",
    "clf = pipeline(task=\"text-classification\", model=fine_tuned_model, tokenizer=tokenizer, truncation=True, device=device, max_length=512)\n",
    "answer = clf(text)\n",
    "\n",
    "# export: join the pipeline output column-wise onto the articles and keep only\n",
    "# the id, label and score columns. The frame gets its own name so that the\n",
    "# test-set output stored in `predictions` (read by the evaluation cells) is\n",
    "# not overwritten.\n",
    "answer_df = pd.DataFrame(answer)\n",
    "topic_predictions = pd.concat([articles.reset_index(drop=True), answer_df], axis=1)\n",
    "topic_predictions = topic_predictions[[\"id\", \"label\", \"score\"]]\n",
    "topic_predictions.to_csv('./predictions_topic_1.csv', sep=',', index=False, encoding='utf-8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZqqKEpTFeN8g"
   },
   "outputs": [],
   "source": [
    "# get articles for prediction part 2\n",
    "articles = pd.read_csv('./all_articles_for_pred_2.csv')\n",
    "\n",
    "# transformers wants a text list as input\n",
    "text = articles['headl_text'].values.tolist()\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"KB/bert-base-swedish-cased\")\n",
    "\n",
    "# load model\n",
    "fine_tuned_model = AutoModelForSequenceClassification.from_pretrained(\"./fine_tuned_model_topic\", num_labels=5)\n",
    "\n",
    "# Use the first GPU when one is visible, otherwise fall back to the CPU (-1)\n",
    "# instead of failing on a hard-coded device index.\n",
    "device = 0 if torch.cuda.is_available() else -1\n",
    "\n",
    "# set up and run pipeline for inference\n",
    "clf = pipeline(task=\"text-classification\", model=fine_tuned_model, tokenizer=tokenizer, truncation=True, device=device, max_length=512)\n",
    "answer = clf(text)\n",
    "\n",
    "# export: join the pipeline output column-wise onto the articles and keep only\n",
    "# the id, label and score columns. The frame gets its own name so that the\n",
    "# test-set output stored in `predictions` (read by the evaluation cells) is\n",
    "# not overwritten.\n",
    "answer_df = pd.DataFrame(answer)\n",
    "topic_predictions = pd.concat([articles.reset_index(drop=True), answer_df], axis=1)\n",
    "topic_predictions = topic_predictions[[\"id\", \"label\", \"score\"]]\n",
    "topic_predictions.to_csv('./predictions_topic_2.csv', sep=',', index=False, encoding='utf-8')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyN8U5K3w1dhGB9skVKMvW9l",
   "gpuType": "L4",
   "machine_shape": "hm",
   "provenance": [
    {
     "file_id": "18WWqSHctzuD3CRIY_1krw64Oma7c1oAu",
     "timestamp": 1697881380924
    }
   ]
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
